import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fsrl.utils import DummyLogger, WandbLogger
from tqdm.auto import trange  # noqa

from osrl.common.net import mlp
from torch.distributions.normal import Normal


class State_AE(nn.Module):
    """
    State encoder and decoder.

    Encodes a state vector of size ``state_dim`` into a latent code of size
    ``encode_dim`` and decodes it back to ``state_dim``.

    Args:
        state_dim: size of the state vector.
        encode_dim: size of the latent code.
        hidden_sizes: hidden layer widths of the encoder MLP; the decoder MLP
            uses the same widths in reverse order. The MLPs are only built
            when ``linear_only`` is False (the ``*_dims`` attributes are
            computed either way).
        linear_only: use a single ``nn.Linear`` for the encoder and another
            for the decoder instead of ReLU MLPs.
    """

    def __init__(self,
                 state_dim: int,
                 encode_dim: int = 32,
                 hidden_sizes: "list[int] | tuple[int, ...]" = (128, 128),
                 linear_only: bool = False):

        super().__init__()
        self.state_dim = state_dim
        self.encode_dim = encode_dim

        # Copy so the caller's sequence is never aliased or mutated.
        self.encoder_hidden_sizes = list(hidden_sizes)
        self.decoder_hidden_sizes = list(reversed(self.encoder_hidden_sizes))

        self.encoder_dims = [self.state_dim] + self.encoder_hidden_sizes + [self.encode_dim]
        self.decoder_dims = [self.encode_dim] + self.decoder_hidden_sizes + [self.state_dim]
        if linear_only:
            self.encoder = nn.Linear(self.state_dim, self.encode_dim)
            self.decoder = nn.Linear(self.encode_dim, self.state_dim)
        else:
            self.encoder = mlp(self.encoder_dims, nn.ReLU)
            self.decoder = mlp(self.decoder_dims, nn.ReLU)

    def forward(self, state):
        """Return ``(state_enc, state_dec)``: the latent code and its reconstruction."""
        state_enc = self.encoder(state)
        state_dec = self.decoder(state_enc)
        return state_enc, state_dec

    def encode(self, state):
        """Map ``state`` to its latent code."""
        return self.encoder(state)

    def decode(self, hidden_state):
        """Map a latent code back to state space."""
        return self.decoder(hidden_state)

class Action_AE(nn.Module):
    """
    Action encoder and decoder.

    Encodes an action of size ``action_dim`` into a latent code of size
    ``encode_dim`` and decodes it back. When ``decode_mu_std`` is True, two
    decoder heads (``mu_decoder`` and ``std_decoder``) are built instead of a
    single ``decoder``. In that mode ``forward`` and single-argument ``decode``
    reconstruct through ``mu_decoder``.

    Args:
        action_dim: size of the action vector.
        encode_dim: size of the latent code.
        hidden_sizes: hidden layer widths of the encoder MLP; the decoder MLP
            uses the same widths in reverse order.
        require_tanh: for the MLP path (``linear_only=False``), pass
            ``nn.Tanh`` as the third argument to ``mlp`` (presumably its
            output activation — confirm against ``osrl.common.net.mlp``).
        decode_mu_std: build separate ``mu_decoder`` / ``std_decoder`` heads
            instead of a single ``decoder``.
        linear_only: make the decoder head(s) single ``nn.Linear`` layers and,
            unless ``decoder_linear_only`` is set, the encoder too.
        decoder_linear_only: only consulted when ``linear_only`` is True; keeps
            the encoder as a ReLU MLP (without Tanh) while the decoder head(s)
            stay linear.
    """

    def __init__(self,
                 action_dim: int,
                 encode_dim: int = 2,
                 hidden_sizes: "list[int] | tuple[int, ...]" = (32, 32),
                 require_tanh: bool = True,
                 decode_mu_std: bool = False,
                 linear_only: bool = False,
                 decoder_linear_only: bool = False):

        super().__init__()
        self.action_dim = action_dim
        self.encode_dim = encode_dim

        # Copy so the caller's sequence is never aliased or mutated.
        self.encoder_hidden_sizes = list(hidden_sizes)
        self.decoder_hidden_sizes = list(reversed(self.encoder_hidden_sizes))

        self.encoder_dims = [self.action_dim] + self.encoder_hidden_sizes + [self.encode_dim]
        self.decoder_dims = [self.encode_dim] + self.decoder_hidden_sizes + [self.action_dim]
        self.decode_mu_std = decode_mu_std
        # Submodules are created encoder-first, then decoder head(s), so the
        # parameter-initialisation RNG order is the same in every branch.
        if linear_only:
            if decoder_linear_only:
                # Only the decoder side is linear; the encoder stays a ReLU MLP
                # without Tanh regardless of ``require_tanh``.
                self.encoder = mlp(self.encoder_dims, nn.ReLU)
            else:
                self.encoder = nn.Linear(self.action_dim, self.encode_dim)
            if decode_mu_std:
                self.mu_decoder = nn.Linear(self.encode_dim, self.action_dim)
                self.std_decoder = nn.Linear(self.encode_dim, self.action_dim)
            else:
                self.decoder = nn.Linear(self.encode_dim, self.action_dim)
        else:
            mlp_args = (nn.ReLU, nn.Tanh) if require_tanh else (nn.ReLU,)
            self.encoder = mlp(self.encoder_dims, *mlp_args)
            if decode_mu_std:
                self.mu_decoder = mlp(self.decoder_dims, *mlp_args)
                self.std_decoder = mlp(self.decoder_dims, *mlp_args)
            else:
                self.decoder = mlp(self.decoder_dims, *mlp_args)

    def _decode_mean(self, hidden):
        """Decode through the single head: ``decoder``, or ``mu_decoder`` when ``decode_mu_std``."""
        # With decode_mu_std no ``self.decoder`` exists, so fall back to the
        # mean head instead of raising AttributeError.
        if self.decode_mu_std:
            return self.mu_decoder(hidden)
        return self.decoder(hidden)

    def forward(self, action, add_noise=False, noise_scale=0.1):
        """
        Encode ``action`` and reconstruct it.

        Args:
            action: input action tensor.
            add_noise: add Gaussian noise (std ``noise_scale``) to the latent
                before decoding; the returned latent itself is noise-free.
            noise_scale: standard deviation multiplier for that noise.

        Returns:
            ``(action_enc, action_dec)``.
        """
        action_enc = self.encoder(action)
        decoder_input = action_enc
        if add_noise:
            decoder_input = action_enc + torch.randn_like(action_enc) * noise_scale
        return action_enc, self._decode_mean(decoder_input)

    def encode(self, action):
        """Map ``action`` to its latent code."""
        return self.encoder(action)

    def decode(self, hidden_mu, hidden_std=None):
        """
        Decode latent code(s) back to action space.

        Returns ``(mu_decoder(hidden_mu), std_decoder(hidden_std))`` when the
        model has two heads and ``hidden_std`` is given. Otherwise returns the
        single-head decoding of ``hidden_mu``; ``hidden_std`` is ignored in that case.
        """
        if self.decode_mu_std and hidden_std is not None:
            return self.mu_decoder(hidden_mu), self.std_decoder(hidden_std)
        return self._decode_mean(hidden_mu)

class inverse_dynamics_model(nn.Module):
    """
    Inverse dynamics model: predicts an action from a (state, next_state) pair.

    The two inputs are concatenated along the last dimension and passed through
    a ReLU MLP built with ``nn.Tanh`` as the third ``mlp`` argument (presumably
    the output activation, bounding predictions to (-1, 1) — confirm against
    ``osrl.common.net.mlp``).

    Args:
        state_dim: size of each of the two inputs; the MLP input size is
            ``2 * state_dim``.
        action_dim: size of the predicted action.
        hidden_sizes: hidden layer widths of the MLP.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_sizes: "list[int] | tuple[int, ...]" = (128, 128)):

        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Copy so the caller's sequence is never aliased or mutated.
        self.hidden_sizes = list(hidden_sizes)
        self.mlp_dims = [self.state_dim * 2] + self.hidden_sizes + [self.action_dim]

        self.mlp = mlp(self.mlp_dims, nn.ReLU, nn.Tanh)

    def forward(self, state, next_state):
        """Return the predicted action for the transition ``state -> next_state``."""
        return self.mlp(torch.cat((state, next_state), dim=-1))

class ActionAETrainer:
    """
    Trainer for an action encoder/decoder.

    Minimises the mean-squared reconstruction error between the decoder output
    and the input action with Adam.

    Args:
        model: module whose ``forward(action, add_noise=..., noise_scale=...)``
            returns ``(action_enc, action_dec)``.
        logger: object exposing ``store(tab=..., **metrics)``. A fresh
            ``DummyLogger`` is created per trainer when omitted or ``None``.
        learning_rate: Adam learning rate.
        device: stored on the trainer; not used by the methods in this class.
        add_noise: forwarded to the model during training steps only.
        noise_scale: forwarded to the model during training steps only.
    """

    def __init__(
            self,
            model,
            logger: "WandbLogger | None" = None,
            # training params
            learning_rate: float = 1e-4,
            device="cpu",
            add_noise: bool = False,
            noise_scale: float = 0.1) -> None:
        self.model = model
        # Create the default logger per instance rather than sharing a single
        # DummyLogger evaluated once at class-definition time.
        self.logger = DummyLogger() if logger is None else logger
        self.device = device
        self.add_noise = add_noise
        self.noise_scale = noise_scale

        self.optim = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

    def train_one_step(self, action):
        """Run one optimisation step on a batch of actions; return the reconstruction loss."""
        _, action_dec = self.model(
            action,
            add_noise=self.add_noise,
            noise_scale=self.noise_scale
        )
        recon_loss = F.mse_loss(action_dec, action, reduction="mean")

        self.optim.zero_grad()
        recon_loss.backward()
        self.optim.step()

        self.logger.store(
            tab="train",
            action_recon_loss=recon_loss.item(),
        )
        return recon_loss.item()

    def eval_one_step(self, action):
        """
        Evaluate reconstruction loss on a batch of actions without updating weights.

        The model runs in eval mode under ``torch.no_grad()`` and with its
        default ``add_noise``/``noise_scale``. Its previous train/eval mode is
        restored afterwards, even if the forward pass raises.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                _, action_dec = self.model(action)
                recon_loss = F.mse_loss(action_dec, action, reduction="mean")
        finally:
            self.model.train(was_training)

        self.logger.store(
            tab="eval",
            action_recon_loss=recon_loss.item(),
        )
        return recon_loss.item()

class StateAETrainer:
    """
    Trainer for a state encoder/decoder regularised by an inverse dynamics model.

    Each step encodes ``state`` and ``next_state`` with ``model`` and
    reconstructs both. ``inverse_dynamics_model`` then predicts ``action``
    from the two latent codes. The loss is::

        mse(state_dec, state) + mse(next_state_dec, next_state)
            + idm_loss_weight * mse(action_pred, action)

    The latents are not detached, so the inverse-dynamics term also trains the
    state encoder. A single Adam optimiser covers both modules.

    Args:
        model: module whose ``forward(x)`` returns ``(x_enc, x_dec)``.
        inverse_dynamics_model: module called as ``idm(state_enc, next_state_enc)``.
        logger: object exposing ``store(tab=..., **metrics)``. A fresh
            ``DummyLogger`` is created per trainer when omitted or ``None``.
        learning_rate: Adam learning rate.
        device: stored on the trainer; not used by the methods in this class.
        idm_loss_weight: weight of the inverse-dynamics loss term.
    """

    def __init__(
            self,
            model,
            inverse_dynamics_model,
            logger: "WandbLogger | None" = None,
            # training params
            learning_rate: float = 1e-4,
            device="cpu",
            idm_loss_weight: float = 1.0) -> None:
        self.model = model
        self.inverse_dynamics_model = inverse_dynamics_model
        # Create the default logger per instance rather than sharing a single
        # DummyLogger evaluated once at class-definition time.
        self.logger = DummyLogger() if logger is None else logger
        self.device = device
        self.idm_loss_weight = idm_loss_weight

        self.optim = torch.optim.Adam(
            [{'params': self.model.parameters()},
             {'params': self.inverse_dynamics_model.parameters()}],
            lr=learning_rate
        )

    def _compute_losses(self, state, action, next_state):
        """Return ``(state_recon, next_state_recon, idm, total)`` loss tensors."""
        state_enc, state_dec = self.model(state)
        next_state_enc, next_state_dec = self.model(next_state)
        action_dec = self.inverse_dynamics_model(state_enc, next_state_enc)
        state_recon_loss = F.mse_loss(state_dec, state, reduction="mean")
        next_state_recon_loss = F.mse_loss(next_state_dec, next_state, reduction="mean")
        idm_loss = F.mse_loss(action_dec, action, reduction="mean")
        total_loss = state_recon_loss + next_state_recon_loss + self.idm_loss_weight * idm_loss
        return state_recon_loss, next_state_recon_loss, idm_loss, total_loss

    def _log(self, tab, state_recon_loss, next_state_recon_loss, idm_loss, total_loss):
        """Store the step's metrics under ``tab``; the two reconstruction losses are summed."""
        self.logger.store(
            tab=tab,
            state_recon_loss=state_recon_loss.item() + next_state_recon_loss.item(),
            idm_loss=idm_loss.item(),
            total_loss=total_loss.item()
        )

    def train_one_step(self, state, action, next_state):
        """Run one optimisation step on a batch of transitions; return the total loss."""
        losses = self._compute_losses(state, action, next_state)
        total_loss = losses[-1]

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()

        self._log("train", *losses)
        return total_loss.item()

    def eval_one_step(self, state, action, next_state):
        """
        Evaluate the total loss on a batch of transitions without updating weights.

        Both the autoencoder and the inverse dynamics model run in eval mode
        under ``torch.no_grad()``. Their previous train/eval modes are restored
        afterwards, even if a forward pass raises.
        """
        modules = (self.model, self.inverse_dynamics_model)
        previous_modes = [module.training for module in modules]
        for module in modules:
            module.eval()
        try:
            with torch.no_grad():
                losses = self._compute_losses(state, action, next_state)
        finally:
            for module, mode in zip(modules, previous_modes):
                module.train(mode)

        self._log("eval", *losses)
        return losses[-1].item()
